In [1]:
# Initialize Notebook
# NOTE: the `Image` name imported here is rebound by `from PIL import Image`
# in the following import cell, so later cells see the PIL class.
from IPython.core.display import HTML,Image
#%run ../library/v1.0.5/init.ipy
# Render a "Toggle Code" button. The embedded script hides or shows every
# `div.input` element (via jQuery `$`) and is also run once on document ready,
# so the code inputs start hidden until the button is pressed.
HTML('''<script> code_show=true;  function code_toggle() {  if (code_show){  $('div.input').hide();  } else {  $('div.input').show();  }  code_show = !code_show }  $( document ).ready(code_toggle); </script> <form action="javascript:code_toggle()"><input type="submit" value="Toggle Code"></form>''')
Out[1]:
In [2]:
import gc, argparse, sys, os, errno
%pylab inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
#sns.set()
#sns.set_style('whitegrid')
import h5py
from PIL import Image
import os
from tqdm import tqdm_notebook as tqdm
import scipy
import sklearn
from scipy.stats import pearsonr
import warnings
warnings.filterwarnings('ignore')
from scipy.io import loadmat
import IPython.display as ipd
import IPython
import librosa.display
import librosa
from pystoi import stoi
Populating the interactive namespace from numpy and matplotlib
/scratch/xc1490/anaconda3/lib/python3.7/site-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.
  import pandas.util.testing as tm
In [3]:
# List the working directory (IPython automagic) to show the input files used below.
ls
attention_mask.mat
ecog.mat
gradient.mat
gt.wav
merge.wav
myCNN2_fineturn_mergenet_16k_vae_v2_20200414-121545label_test.tsv
pred.wav
result_NY717.4areas_files/
result_NY717.4areas.html
result_NY717.4areas.ipynb
spectrogram_GT.mat
spectrogram_prediction.mat
wav/
waveform_GT.mat
In [4]:
# Load the word labels from the 'tsv' file that sits next to this notebook.
# os.listdir() returns names in arbitrary order, so the candidates are sorted
# to make the choice reproducible when more than one file matches.
label_files = sorted(name for name in os.listdir('.') if name[-3:] == 'tsv')
if not label_files:
    # Fail with an explicit message instead of an opaque IndexError on [0].
    raise FileNotFoundError("no file ending in 'tsv' found in the working directory")
select_word = np.loadtxt(label_files[0], dtype='str')

Spectrogram

  • spec_gt (ground truth): upper half of each panel
  • spec_pred (prediction): lower half of each panel
In [5]:
# Ground-truth and predicted spectrograms, read from the exported .mat files.
spec_gt = loadmat('spectrogram_GT.mat')['GT_STFT_test_spkr']
spec_pred = loadmat('spectrogram_prediction.mat')['pred_STFT_test']
# Swap the last two axes of each array and stack them along axis 1, so that for
# every sample the ground truth sits above the prediction in one image.
# `np` is used instead of the bare `numpy` name, which is only available as a
# side effect of `%pylab inline`.
spec_concat = np.concatenate(
    (np.swapaxes(spec_gt, 2, 1), np.swapaxes(spec_pred, 2, 1)),
    axis=1,
)
In [6]:
# Grid of stacked spectrograms (ground truth on top, prediction below), one
# panel per sample, titled with the corresponding entry of `select_word`.
row_nums = 18
col_nums = 10
fig, ax = plt.subplots(row_nums, col_nums, figsize=(col_nums*2, row_nums*1.5))
# Take the colormap from `plt.cm` rather than the `cm` name injected by %pylab.
cmap = plt.cm.coolwarm
# `ax.flat` walks the grid row by row, so `idx` equals row*col_nums + col.
for idx, panel in enumerate(ax.flat):
    panel.imshow(spec_concat[idx], cmap=cmap)
    panel.set_title(select_word[idx])
fig.tight_layout()

Waveform

In [7]:
# Read the three audio files at 16 kHz; librosa.load returns (samples, rate)
# and only the sample array is kept.
wave_gt, wave_pred, wave_merge = (
    librosa.load(wav_path, sr=16000)[0]
    for wav_path in ('gt.wav', 'pred.wav', 'merge.wav')
)

ground truth

In [8]:
# Embedded audio player for the ground-truth waveform (16 kHz).
display(ipd.Audio(data=wave_gt, rate=16000))

reconstructed audio

In [9]:
# Embedded audio player for the reconstructed (predicted) waveform (16 kHz).
display(ipd.Audio(data=wave_pred, rate=16000))

merged audio

In [10]:
# Embedded audio player for the merged waveform (16 kHz).
display(ipd.Audio(data=wave_merge, rate=16000))
In [11]:
# Number of samples shown per panel; the STOI cell further down reuses this value.
interval = 16384
row_nums = 18
col_nums = 10
# Two axes rows per grid row: ground truth on the even row, prediction on the
# odd row directly beneath it.
fig, ax = plt.subplots(row_nums*2, col_nums, figsize=(col_nums*2, row_nums*1.5))
for row in range(row_nums):
    gt_axes = ax[row*2]
    pred_axes = ax[row*2 + 1]
    for col in range(col_nums):
        word_idx = row*col_nums + col
        segment = slice(word_idx*interval, (word_idx + 1)*interval)
        gt_axes[col].set_title(select_word[word_idx])
        gt_axes[col].plot(wave_gt[segment])
        pred_axes[col].plot(wave_pred[segment])
        gt_axes[col].axis('off')
        pred_axes[col].axis('off')
fig.tight_layout()

Visualization

Attention

In [12]:
# Keep index 0 of the fifth axis of the stored attention masks, then average
# over the first axis.
attention = loadmat('attention_mask.mat')['ams_test']
attention = attention[:, :, :, :, 0]
average_mask = np.mean(attention, axis=0)
average_mask.shape
Out[12]:
(36, 15, 15)
In [13]:
# Heat-maps of the averaged attention mask, one panel per slice along its first
# axis, all sharing a single colour scale.
row_nums = 3
col_nums = 12
fig, ax = plt.subplots(row_nums, col_nums, figsize=(col_nums*3, row_nums*3))
# The colour limits are the same for every panel, so compute them once: the
# lower limit is the mean of the non-zero entries, the upper limit the maximum.
# `np.mean` / `plt.cm` replace the `mean` / `cm` names injected by %pylab.
vmin = np.mean(average_mask[average_mask != 0])
vmax = np.max(average_mask)
# `ax.flat` walks the grid row by row, so `idx` equals row*col_nums + col.
for idx, panel in enumerate(ax.flat):
    im = panel.imshow(average_mask[idx], cmap=plt.cm.hot, vmin=vmin, vmax=vmax)
    panel.axis('off')

# Reserve a strip on the right and draw one colorbar there for the whole grid.
fig.subplots_adjust(right=0.84)
cbar_ax = fig.add_axes([0.85, 0.15, 0.02, 0.7])
fig.colorbar(im, cax=cbar_ax)
Out[13]:
<matplotlib.colorbar.Colorbar at 0x7ff58489c860>

ECoG

In [14]:
# Take the first entry along the leading axis of the stored array, view it as
# (180, 176, 15, 15) and drop 16 steps from each end of axis 1 (176 -> 144).
# NOTE(review): presumably (samples, time, 15x15 grid) - confirm against the exporter.
ecog = loadmat('ecog.mat')['GT_STFT_test_ecog'][0, :, :, :].reshape(180, 176, 15, 15)[:, 16:-16, :, :]
# Collapse axis 1 into 36 windows of 4 steps: maximum inside each window, then
# the mean over axis 0.
ecog_ = np.zeros([36, 15, 15])
for i in range(36):
    ecog_[i] = np.mean(np.max(ecog[:, i*4:(i+1)*4, :, :], 1), axis=0)

row_nums = 3
col_nums = 12
fig, ax = plt.subplots(row_nums, col_nums, figsize=(col_nums*3, row_nums*3))
# Shared colour limits, computed once instead of once per panel. `np.mean` and
# `plt.cm` replace the `mean` / `cm` names injected by %pylab.
vmin = np.mean(ecog_[ecog_ != 0])
vmax = np.max(ecog_)
# `ax.flat` walks the grid row by row, so `idx` equals row*col_nums + col.
for idx, panel in enumerate(ax.flat):
    im = panel.imshow(ecog_[idx], cmap=plt.cm.hot, vmin=vmin, vmax=vmax)
    panel.axis('off')

# Reserve a strip on the right and draw one colorbar there for the whole grid.
fig.subplots_adjust(right=0.84)
cbar_ax = fig.add_axes([0.85, 0.15, 0.02, 0.7])
fig.colorbar(im, cax=cbar_ax)
Out[14]:
<matplotlib.colorbar.Colorbar at 0x7ff584b1be80>

Gradient

In [15]:
# View the stored gradients as (-1, 176, 15, 15) and drop 16 steps from each
# end of axis 1 (176 -> 144). A single reshape is equivalent to the previous
# (-1, 176, 225) -> (-1, 176, 15, 15) chain.
# NOTE(review): presumably (samples, time, 15x15 grid) - confirm against the exporter.
gradient = loadmat('gradient.mat')['grad_loss2inp_test'].reshape(-1, 176, 15, 15)[:, 16:-16, :, :]
# Keep only the positive part. The earlier `np.abs` call that followed the
# zeroing of negatives had no effect, because nothing negative was left.
gradient = np.maximum(gradient, 0)
# Collapse axis 1 into 36 windows of 4 steps: maximum inside each window, then
# the mean over axis 0.
gradient_ = np.zeros([36, 15, 15])
for i in range(36):
    gradient_[i] = np.mean(np.max(gradient[:, i*4:(i+1)*4, :, :], 1), axis=0)

row_nums = 3
col_nums = 12
fig, ax = plt.subplots(row_nums, col_nums, figsize=(col_nums*3, row_nums*3))
# Shared colour limits, computed once instead of once per panel. `np.mean` and
# `plt.cm` replace the `mean` / `cm` names injected by %pylab.
vmin = np.mean(gradient_[gradient_ != 0])
vmax = np.max(gradient_)
# `ax.flat` walks the grid row by row, so `idx` equals row*col_nums + col.
for idx, panel in enumerate(ax.flat):
    im = panel.imshow(gradient_[idx], cmap=plt.cm.hot, vmin=vmin, vmax=vmax)
    panel.axis('off')

# Reserve a strip on the right and draw one colorbar there for the whole grid.
fig.subplots_adjust(right=0.84)
cbar_ax = fig.add_axes([0.85, 0.15, 0.02, 0.7])
fig.colorbar(im, cax=cbar_ax)
Out[15]:
<matplotlib.colorbar.Colorbar at 0x7ff57f7eb438>

Metrics

PCC & MSE (Pearson correlation coefficient and variance-normalized mean squared error)

In [16]:
def MSE_pcc(A,B,ax=None):
    """Return the variance-normalized MSE and the Pearson correlation of A and B.

    Parameters
    ----------
    A, B : array-like objects supporting arithmetic, `.var()` and `.ravel()`
        The two arrays to compare; B supplies the normalizing variance.
    ax : unused
        Accepted for call compatibility; the function never reads it.

    Returns
    -------
    (mse, pcc) : tuple
        `mse` is the mean of the squared difference divided by `B.var()`;
        `pcc` is the first value returned by `scipy.stats.pearsonr` on the
        flattened inputs.
    """
    squared_error = (A - B) ** 2
    normalized_error = squared_error / B.var()
    flat_a = A.ravel()
    flat_b = B.ravel()
    correlation = pearsonr(flat_a, flat_b)
    return np.mean(normalized_error), correlation[0]
def analyze(predict,GT_STFT_test_spkr):
    """Score every sample with MSE_pcc and plot histograms of both metrics.

    Parameters
    ----------
    predict : array whose first axis indexes samples
    GT_STFT_test_spkr : array indexed the same way as `predict`

    Returns
    -------
    (mse, pcc) : two 1-D arrays with one value per sample.

    Side effect: creates a 1x2 figure (MSE histogram on the left, PCC histogram
    on the right) whose titles show the rounded mean and standard deviation.
    """
    n_samples = predict.shape[0]
    mse_per_sample = np.zeros([n_samples])
    pcc_per_sample = np.zeros([n_samples])
    for idx in range(n_samples):
        mse_per_sample[idx], pcc_per_sample[idx] = MSE_pcc(predict[idx], GT_STFT_test_spkr[idx])

    fig, (mse_ax, pcc_ax) = plt.subplots(1, 2, figsize=(16, 4))

    mse_ax.hist(mse_per_sample, bins=25, color='b')
    mse_summary = (np.round(mse_per_sample.mean(), 3), np.round(mse_per_sample.std(), 3))
    mse_ax.set_title('MSE: %g(%g)' % mse_summary)

    pcc_ax.hist(pcc_per_sample, bins=50, color='g')
    pcc_summary = (np.round(pcc_per_sample.mean(), 3), np.round(pcc_per_sample.std(), 3))
    pcc_ax.set_title('PCC: %g(%g)' % pcc_summary)

    return mse_per_sample, pcc_per_sample

# Re-read both spectrogram arrays so this cell does not depend on earlier
# cells, then compute and plot the per-sample metrics.
spec_gt, spec_pred = (
    loadmat('spectrogram_GT.mat')['GT_STFT_test_spkr'],
    loadmat('spectrogram_prediction.mat')['pred_STFT_test'],
)
mse, pcc = analyze(predict=spec_pred, GT_STFT_test_spkr=spec_gt)

STOI

In [17]:
# STOI score for each of the 180 consecutive `interval`-sample segments of the
# ground-truth and predicted waveforms (16 kHz, non-extended variant).
# `interval` is defined in the waveform-grid cell above.
stois = np.array(
    [
        stoi(wave_gt[start:start + interval], wave_pred[start:start + interval],
             16000, extended=False)
        for start in range(0, 180 * interval, interval)
    ],
    dtype=float,
)
In [18]:
# Histogram of the per-segment STOI scores; the title reports mean(std),
# both rounded to three decimals.
fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(stois, bins=25, color='b')
stoi_mean = np.round(stois.mean(), 3)
stoi_std = np.round(stois.std(), 3)
ax.set_title('STOI: %g(%g)' % (stoi_mean, stoi_std))
Out[18]:
Text(0.5, 1.0, 'STOI: 0.494(0.167)')